package org.zjvis.datascience.service.dag;

import com.alibaba.fastjson.JSONObject;

import java.util.Map;
import java.util.concurrent.Callable;

import org.apache.commons.lang3.StringUtils;
import org.zjvis.datascience.common.constant.Constant;
import org.zjvis.datascience.common.dto.TaskInstanceDTO;
import org.zjvis.datascience.common.enums.JobStatus;
import org.zjvis.datascience.common.enums.TaskInstanceStatus;
import org.zjvis.datascience.common.util.RestTemplateUtil;
import org.zjvis.datascience.common.vo.JobStatusVO;
import org.zjvis.datascience.service.TaskInstanceService;

/**
 * @description Spark任务调度器
 * @date 2021-12-10
 */
public class SparkSubmitRunner implements Callable<TaskRunnerResult> {

    private String errorTpl = "{\"status\":500, \"error_msg\":\"%s\"}";
    private int maxRetryTimes = 1;
    private int sleepTime = 20000;
    private RestTemplateUtil restTemplateUtil;
    private TaskInstanceDTO instance;
    private Map<String, TaskInstanceDTO> appIdInstanceMap;
    private TaskInstanceService taskInstanceService;

    public SparkSubmitRunner(RestTemplateUtil restTemplateUtil, TaskInstanceDTO instance,
                             Map appIdInstanceMap, TaskInstanceService taskInstanceService) {
        this.restTemplateUtil = restTemplateUtil;
        this.instance = instance;
        this.appIdInstanceMap = appIdInstanceMap;
        this.taskInstanceService = taskInstanceService;
    }

    @Override
    public TaskRunnerResult call() throws Exception {
        int index = 0;
        TaskRunnerResult result = null;
        if (instance.hasPrecautionaryError()) {
            return new TaskRunnerResult(500, String.format(errorTpl, "error happens when init stage."));
        }
        while (index < maxRetryTimes) {
            try {
                String algoName = instance.getSqlText().split(" ")[0];
                if (algoName.equals("linear")) {
                    result = this.exec();
                } else {
                    result = this.submit();
                }
                if (result.getStatus() == 0) {
                    break;
                }
                ++index;
            } catch (Exception e) {
                ++index;
                Thread.sleep(sleepTime);
                result = new TaskRunnerResult(500,
                        String.format(errorTpl, e.getMessage().replaceAll("\"", "'")));
            }
        }
        return result;
    }

    public TaskRunnerResult submit() {
        String appArgs = instance.getSqlText();
        if (StringUtils.isEmpty(appArgs) || !appArgs.contains(" -f ")) {
            return new TaskRunnerResult(500, String.format(errorTpl, "task not config"));
        }
        //在最后拼上taskIns的id参数，算子需要该参数将输出元数据写入mysql
        appArgs = appArgs + " -id " + instance.getId();
        String json = restTemplateUtil.runJob(appArgs);
        if (StringUtils.isEmpty(json)) {
            return new TaskRunnerResult(500, String.format(errorTpl, "job submit fail"));
        }
        JSONObject jsonObj = JSONObject.parseObject(json);
        String status = jsonObj.getString("status");
        String jobId = jsonObj.getString("jobId");
        instance.setApplicationId(jobId);
        taskInstanceService.update(instance);
        if ("ERROR".equals(status)) {
            JSONObject result = jsonObj.getJSONObject("result");
            String message = result.getString("message");
            return new TaskRunnerResult(500,
                    String.format(errorTpl, message.replaceAll("\"", "'")));
        }
        return new TaskRunnerResult(0, null);
    }

    public TaskRunnerResult exec() {
        //在最后拼上taskIns的id参数，算子需要该参数将输出元数据写入mysql
        String appArgs =
                "-a " + instance.getSqlText() + " -id " + instance.getId() + " -u " + instance
                        .getUserId();
        //TODO
        String applicationId = restTemplateUtil.submitJob(appArgs);
        if (StringUtils.isEmpty(applicationId)) {
            return new TaskRunnerResult(500, String.format(errorTpl, "job submit fail"));
        }
        instance.setApplicationId(applicationId);
        taskInstanceService.update(instance);
        appIdInstanceMap.put(applicationId, instance);
        int queryTimes = 0;
        for (; queryTimes <= maxRetryTimes; ) {
            //每隔1S查询一次任务状态
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
            }
            //已经成功或失败或停止的退出
            if (TaskInstanceStatus.SUCCESS.toString().equals(instance.getStatus()) ||
                    TaskInstanceStatus.FAIL.toString().equals(instance.getStatus()) ||
                    TaskInstanceStatus.KILLED.toString().equals(instance.getStatus())) {
                return new TaskRunnerResult(0, null);
            }
            //查询spark任务状态
            JobStatusVO jobStatusVO = restTemplateUtil
                    .queryJobStatus(applicationId, instance.getLogInfo());
            //查询失败，下次再查询
            if (null == jobStatusVO || StringUtils.isEmpty(jobStatusVO.getState())) {
                queryTimes++;
                continue;
            }
            //未结束的任务跳过
            if (!JobStatus.jobIsEnd(jobStatusVO.getState())) {
                continue;
            }
            //成功后更新
            if (JobStatus.SUCCEEDED.toString().equals(jobStatusVO.getFinalStatus())) {
                instance.setStatus(TaskInstanceStatus.SUCCESS.toString());
                //失败后更新
            } else if (JobStatus.FAILED.toString().equals(jobStatusVO.getFinalStatus())) {
                instance.setStatus(TaskInstanceStatus.FAIL.toString());
                instance.setLogInfo(String.format(Constant.errorTpl, jobStatusVO.getDiagnostics()
                        .replaceAll("\"", "'")));
                //停止后更新
            } else if (JobStatus.KILLED.toString().equals(jobStatusVO.getFinalStatus())) {
                instance.setStatus(TaskInstanceStatus.KILLED.toString());
                instance.setLogInfo(String.format(Constant.errorTpl, jobStatusVO.getDiagnostics()
                        .replaceAll("\"", "'")));
            }
            instance.setProgress(100);
            taskInstanceService.update(instance);
            appIdInstanceMap.remove(applicationId);
            return new TaskRunnerResult(0, null);
        }
        //查询失败超过3次到这里
        instance.setStatus(TaskInstanceStatus.FAIL.toString());
        instance.setLogInfo(
                String.format(Constant.errorTpl, "queryJobStatus fail more than three times"));
        taskInstanceService.update(instance);
        return new TaskRunnerResult(0, null);
    }
}